import math
import numpy as np

def softmax_np(x):
    """Return a numerically stable softmax of every item along the first axis of *x*.

    *x* is converted with ``np.array`` and iterated over its first axis.
    Each item has its own maximum subtracted before exponentiation, so large
    scores do not overflow, and is then divided by the sum of its
    exponentials. For a 2-D input this is a row-wise softmax.

    Each item is normalised as a whole. As a result, a 1-D input yields an
    array of ones (every scalar is its own distribution), and a 3-D input
    normalises each 2-D slice jointly.
    """
    def _normalise(item):
        # Shifting by the max leaves the softmax unchanged but keeps exp() finite.
        shifted = item - np.max(item)
        weights = np.exp(shifted)
        return weights / np.sum(weights)

    return np.array([_normalise(item) for item in np.array(x)])

def perplexity(testset, topic_word, doc_topic):
    """Calculate the perplexity of an LDA model on a held-out test set.

    perplexity = exp(-sum_d log p(d) / sum_d N_d), where
    log p(d) = sum_w n_dw * log(sum_z p(z|d) * p(w|z)) and N_d is the
    number of words in document d.

    ``topic_word`` and ``doc_topic`` are both passed through ``softmax_np``
    before use, i.e. they are treated as unnormalised scores.
    NOTE(review): if callers pass already-normalised probability tables, the
    softmax distorts them -- confirm what callers supply.

    Args:
        testset: word-count matrix with one row per test document and one
            column per vocabulary word (list of lists or 2-D array).
        topic_word: per-topic word scores; row z, column w scores word w
            under topic z.
        doc_topic: per-document topic scores; row i must correspond to row
            i of ``testset``.

    Returns:
        The perplexity as a Python float.

    Raises:
        ValueError: if ``testset`` is empty or not 2-D, contains no words,
            or the model gives an observed word zero probability.
    """
    counts = np.asarray(testset, dtype=float)
    if counts.ndim != 2 or counts.shape[0] == 0:
        raise ValueError('testset must be a non-empty 2-D document-word count matrix')

    print('doc size =', counts.shape[0])
    print('vocab size =', counts.shape[1])
    print('topic num =', len(topic_word))

    total_words = counts.sum()
    if total_words == 0:
        # Perplexity is undefined without any observed words (0/0).
        raise ValueError('testset contains no words')

    # Row-wise softmax turns the raw scores into p(w|z) and p(z|d).
    topic_word = softmax_np(topic_word)
    doc_topic = softmax_np(doc_topic)

    # p(w|d) = sum_z p(z|d) * p(w|z) for every (document, word) pair.
    # Rows/columns beyond the test set's shape are not needed and are dropped
    # so the mask below lines up cell-for-cell with the counts.
    doc_word_prob = doc_topic.dot(topic_word)[:counts.shape[0], :counts.shape[1]]

    # Only words that occur in a document contribute to log p(d); masking
    # also avoids evaluating log() on cells with a zero count.
    observed = counts != 0
    observed_prob = doc_word_prob[observed]
    if np.any(observed_prob <= 0):
        raise ValueError('model assigns zero probability to an observed word')
    log_likelihood = np.sum(counts[observed] * np.log(observed_prob))

    return math.exp(-log_likelihood / total_words)


def PrecisionAtR(doc_topic, doc_class, R):
    """Mean precision@R of cosine-similarity retrieval over a document set.

    Each document is used in turn as a query against all *other* documents.
    The R most cosine-similar documents are retrieved, and the fraction that
    share the query's label is recorded. The result is the mean of those
    fractions over all queries. Ties in similarity are broken by document
    index.

    Args:
        doc_topic: 2-D array-like with one topic vector per document.
        doc_class: sequence of scalar labels aligned with the rows of
            ``doc_topic``.
        R: number of documents to retrieve per query. A value below 1 is
            read as a fraction of the collection size, ``int(doc_num * R)``,
            clamped to at least 1. Values >= 1 are truncated to an int.
            Precision is always divided by R, even when fewer than R other
            documents exist.

    Returns:
        Mean precision as a Python float in [0, 1].

    Raises:
        ValueError: if ``doc_topic`` is empty or ``R`` is not positive.
    """
    vectors = np.asarray(doc_topic, dtype=float)
    labels = np.asarray(doc_class)
    doc_num = len(vectors)
    if doc_num == 0:
        raise ValueError('doc_topic must contain at least one document')
    if R <= 0:
        raise ValueError('R must be positive')
    if R < 1:
        # Fractional R: a share of the collection, but never zero documents.
        R = max(1, int(doc_num * R))
    R = int(R)

    # Cosine similarity. Pairs involving an all-zero vector get similarity 0
    # instead of NaN from 0/0.
    norms = np.linalg.norm(vectors, axis=1)
    denom = np.outer(norms, norms)
    sim = np.divide(vectors @ vectors.T, denom,
                    out=np.zeros((doc_num, doc_num)), where=denom > 0)

    # A query must never retrieve itself. Push self-similarity to -inf so it
    # sorts last, and never take more than the other doc_num - 1 documents.
    np.fill_diagonal(sim, -np.inf)
    k = min(R, doc_num - 1)
    neighbours = np.argsort(-sim, axis=1, kind='stable')[:, :k]

    hits = (labels[neighbours] == labels[:, None]).sum(axis=1)
    return float(np.mean(hits / R))